{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 基于自注意力机制的语言模型\n",
    "\n",
    "先做一个简单地拼音映射为汉字的实验。\n",
    "\n",
    "## 1. 数据处理\n",
    "- 读取数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"data/zh.tsv\", 'r', encoding='utf-8') as fout:\n",
    "    data = fout.readlines()[:100]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 50129.13it/s]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "inputs = []\n",
    "labels = []\n",
    "for i in tqdm(range(len(data))):\n",
    "    key, pny, hanzi = data[i].split('\\t')\n",
    "    inputs.append(pny.split(' '))\n",
    "    labels.append(hanzi.strip('\\n').split(' '))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[['lv4', 'shi4', 'yang2', 'chun1', 'yan1', 'jing3', 'da4', 'kuai4', 'wen2', 'zhang1', 'de', 'di3', 'se4', 'si4', 'yue4', 'de', 'lin2', 'luan2', 'geng4', 'shi4', 'lv4', 'de2', 'xian1', 'huo2', 'xiu4', 'mei4', 'shi1', 'yi4', 'ang4', 'ran2'], ['ta1', 'jin3', 'ping2', 'yao1', 'bu4', 'de', 'li4', 'liang4', 'zai4', 'yong3', 'dao4', 'shang4', 'xia4', 'fan1', 'teng2', 'yong3', 'dong4', 'she2', 'xing2', 'zhuang4', 'ru2', 'hai3', 'tun2', 'yi1', 'zhi2', 'yi3', 'yi1', 'tou2', 'de', 'you1', 'shi4', 'ling3', 'xian1'], ['pao4', 'yan3', 'da3', 'hao3', 'le', 'zha4', 'yao4', 'zen3', 'me', 'zhuang1', 'yue4', 'zheng4', 'cai2', 'yao3', 'le', 'yao3', 'ya2', 'shu1', 'de', 'tuo1', 'qu4', 'yi1', 'fu2', 'guang1', 'bang3', 'zi', 'chong1', 'jin4', 'le', 'shui3', 'cuan4', 'dong4'], ['ke3', 'shei2', 'zhi1', 'wen2', 'wan2', 'hou4', 'ta1', 'yi1', 'zhao4', 'jing4', 'zi', 'zhi3', 'jian4', 'zuo3', 'xia4', 'yan3', 'jian3', 'de', 'xian4', 'you4', 'cu1', 'you4', 'hei1', 'yu3', 'you4', 'ce4', 'ming2', 'xian3', 'bu4', 'dui4', 'cheng1'], ['qi1', 'shi2', 'nian2', 'dai4', 'mo4', 'wo3', 'wai4', 'chu1', 'qiu2', 'xue2', 'mu3', 'qin1', 'ding1', 'ning2', 'wo3', 'chi1', 'fan4', 'yao4', 'xi4', 'jue2', 'man4', 'yan4', 'xue2', 'xi2', 'yao4', 'shen1', 'zuan1', 'xi4', 'yan2']]\n",
      "\n",
      "[['绿', '是', '阳', '春', '烟', '景', '大', '块', '文', '章', '的', '底', '色', '四', '月', '的', '林', '峦', '更', '是', '绿', '得', '鲜', '活', '秀', '媚', '诗', '意', '盎', '然'], ['他', '仅', '凭', '腰', '部', '的', '力', '量', '在', '泳', '道', '上', '下', '翻', '腾', '蛹', '动', '蛇', '行', '状', '如', '海', '豚', '一', '直', '以', '一', '头', '的', '优', '势', '领', '先'], ['炮', '眼', '打', '好', '了', '炸', '药', '怎', '么', '装', '岳', '正', '才', '咬', '了', '咬', '牙', '倏', '地', '脱', '去', '衣', '服', '光', '膀', '子', '冲', '进', '了', '水', '窜', '洞'], ['可', '谁', '知', '纹', '完', '后', '她', '一', '照', '镜', '子', '只', '见', '左', '下', '眼', '睑', '的', '线', '又', '粗', '又', '黑', '与', '右', '侧', '明', '显', '不', '对', '称'], ['七', '十', '年', '代', '末', '我', '外', '出', '求', '学', '母', '亲', '叮', '咛', '我', '吃', '饭', '要', '细', '嚼', '慢', '咽', '学', '习', '要', '深', '钻', '细', '研']]\n"
     ]
    }
   ],
   "source": [
    "print(inputs[:5])\n",
    "print()\n",
    "print(labels[:5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 7712.39it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 5277.65it/s]\n"
     ]
    }
   ],
   "source": [
    "def get_vocab(data):\n",
    "    vocab = ['<PAD>']\n",
    "    for line in tqdm(data):\n",
    "        for char in line:\n",
    "            if char not in vocab:\n",
    "                vocab.append(char)\n",
    "    return vocab\n",
    "\n",
    "pny2id = get_vocab(inputs)\n",
    "han2id = get_vocab(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['<PAD>', 'lv4', 'shi4', 'yang2', 'chun1', 'yan1', 'jing3', 'da4', 'kuai4', 'wen2']\n",
      "['<PAD>', '绿', '是', '阳', '春', '烟', '景', '大', '块', '文']\n"
     ]
    }
   ],
   "source": [
    "print(pny2id[:10])\n",
    "print(han2id[:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 6591.81it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3855.73it/s]\n"
     ]
    }
   ],
   "source": [
    "input_num = [[pny2id.index(pny) for pny in line] for line in tqdm(inputs)]\n",
    "label_num = [[han2id.index(han) for han in line] for line in tqdm(labels)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[  1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  11  16  17\n",
      "   18   2   1  19  20  21  22  23  24  25  26  27   0   0   0]\n",
      " [ 28  29  30  31  32  11  33  34  35  36  37  38  39  40  41  36  42  43\n",
      "   44  45  46  47  48  49  50  51  49  52  11  53   2  54  20]\n",
      " [ 55  56  57  58  59  60  61  62  63  64  15  65  66  67  59  67  68  69\n",
      "   11  70  71  49  72  73  74  75  76  77  59  78  79  42   0]\n",
      " [ 80  81  82   9  83  84  28  49  85  86  75  87  88  89  39  56  90  11\n",
      "   91  92  93  92  94  95  92  96  97  98  32  99 100   0   0]]\n",
      "[[  1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  11  16  17\n",
      "   18   2   1  19  20  21  22  23  24  25  26  27   0   0   0]\n",
      " [ 28  29  30  31  32  11  33  34  35  36  37  38  39  40  41  42  43  44\n",
      "   45  46  47  48  49  50  51  52  50  53  11  54  55  56  57]\n",
      " [ 58  59  60  61  62  63  64  65  66  67  68  69  70  71  62  71  72  73\n",
      "   74  75  76  77  78  79  80  81  82  83  62  84  85  86   0]\n",
      " [ 87  88  89  90  91  92  93  50  94  95  81  96  97  98  39  59  99  11\n",
      "  100 101 102 101 103 104 105 106 107 108 109 110 111   0   0]]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "def get_batch(input_data, label_data, batch_size):\n",
    "    batch_num = len(input_data) // batch_size\n",
    "    for k in range(batch_num):\n",
    "        begin = k * batch_size\n",
    "        end = begin + batch_size\n",
    "        input_batch = input_data[begin:end]\n",
    "        label_batch = label_data[begin:end]\n",
    "        max_len = max([len(line) for line in input_batch])\n",
    "        input_batch = np.array([line + [0] * (max_len - len(line)) for line in input_batch])\n",
    "        label_batch = np.array([line + [0] * (max_len - len(line)) for line in label_batch])\n",
    "        yield input_batch, label_batch\n",
    "        \n",
    "        \n",
    "batch = get_batch(input_num, label_num, 4)\n",
    "input_batch, label_batch = next(batch)\n",
    "print(input_batch)\n",
    "print(label_batch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.模型搭建\n",
    "模型采用self-attention,模型结构如下：\n",
    "\n",
    "![image.png](attachment:image.png)\n",
    "\n",
    "我们只需要搭建左侧编码器即可，不用搭建右侧解码器。\n",
    "\n",
    "论文地址：https://arxiv.org/abs/1706.03762\n",
    "\n",
    "模型代码搭建我们直接参考开源的代码：\n",
    "https://github.com/Kyubyong/transformer/blob/master/modules.py\n",
    "\n",
    "我们只需要注意每一快的输入输出数据形式怎样即可。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.1 构造建模组件"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下面代码实现了图片结构中的各个功能组件。\n",
    "![image.png](attachment:image.png)\n",
    "\n",
    "#### layer norm层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def normalize(inputs, \n",
    "              epsilon = 1e-8,\n",
    "              scope=\"ln\",\n",
    "              reuse=None):\n",
    "    '''Applies layer normalization.\n",
    "\n",
    "    Args:\n",
    "      inputs: A tensor with 2 or more dimensions, where the first dimension has\n",
    "        `batch_size`.\n",
    "      epsilon: A floating number. A very small number for preventing ZeroDivision Error.\n",
    "      scope: Optional scope for `variable_scope`.\n",
    "      reuse: Boolean, whether to reuse the weights of a previous layer\n",
    "        by the same name.\n",
    "\n",
    "    Returns:\n",
    "      A tensor with the same shape and data dtype as `inputs`.\n",
    "    '''\n",
    "    with tf.variable_scope(scope, reuse=reuse):\n",
    "        inputs_shape = inputs.get_shape()\n",
    "        params_shape = inputs_shape[-1:]\n",
    "\n",
    "        mean, variance = tf.nn.moments(inputs, [-1], keep_dims=True)\n",
    "        beta= tf.Variable(tf.zeros(params_shape))\n",
    "        gamma = tf.Variable(tf.ones(params_shape))\n",
    "        normalized = (inputs - mean) / ( (variance + epsilon) ** (.5) )\n",
    "        outputs = gamma * normalized + beta\n",
    "\n",
    "    return outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### embedding层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def embedding(inputs, \n",
    "              vocab_size, \n",
    "              num_units, \n",
    "              zero_pad=True, \n",
    "              scale=True,\n",
    "              scope=\"embedding\", \n",
    "              reuse=None):\n",
    "    '''Embeds a given tensor.\n",
    "    Args:\n",
    "      inputs: A `Tensor` with type `int32` or `int64` containing the ids\n",
    "         to be looked up in `lookup table`.\n",
    "      vocab_size: An int. Vocabulary size.\n",
    "      num_units: An int. Number of embedding hidden units.\n",
    "      zero_pad: A boolean. If True, all the values of the fist row (id 0)\n",
    "        should be constant zeros.\n",
    "      scale: A boolean. If True. the outputs is multiplied by sqrt num_units.\n",
    "      scope: Optional scope for `variable_scope`.\n",
    "      reuse: Boolean, whether to reuse the weights of a previous layer\n",
    "        by the same name.\n",
    "    Returns:\n",
    "      A `Tensor` with one more rank than inputs's. The last dimensionality\n",
    "        should be `num_units`.\n",
    "\n",
    "    For example,\n",
    "\n",
    "    ```\n",
    "    import tensorflow as tf\n",
    "\n",
    "    inputs = tf.to_int32(tf.reshape(tf.range(2*3), (2, 3)))\n",
    "    outputs = embedding(inputs, 6, 2, zero_pad=True)\n",
    "    with tf.Session() as sess:\n",
    "        sess.run(tf.global_variables_initializer())\n",
    "        print sess.run(outputs)\n",
    "    >>\n",
    "    [[[ 0.          0.        ]\n",
    "      [ 0.09754146  0.67385566]\n",
    "      [ 0.37864095 -0.35689294]]\n",
    "     [[-1.01329422 -1.09939694]\n",
    "      [ 0.7521342   0.38203377]\n",
    "      [-0.04973143 -0.06210355]]]\n",
    "    ```\n",
    "\n",
    "    ```\n",
    "    import tensorflow as tf\n",
    "\n",
    "    inputs = tf.to_int32(tf.reshape(tf.range(2*3), (2, 3)))\n",
    "    outputs = embedding(inputs, 6, 2, zero_pad=False)\n",
    "    with tf.Session() as sess:\n",
    "        sess.run(tf.global_variables_initializer())\n",
    "        print sess.run(outputs)\n",
    "    >>\n",
    "    [[[-0.19172323 -0.39159766]\n",
    "      [-0.43212751 -0.66207761]\n",
    "      [ 1.03452027 -0.26704335]]\n",
    "     [[-0.11634696 -0.35983452]\n",
    "      [ 0.50208133  0.53509563]\n",
    "      [ 1.22204471 -0.96587461]]]\n",
    "    ```    \n",
    "    '''\n",
    "    with tf.variable_scope(scope, reuse=reuse):\n",
    "        lookup_table = tf.get_variable('lookup_table',\n",
    "                                       dtype=tf.float32,\n",
    "                                       shape=[vocab_size, num_units],\n",
    "                                       initializer=tf.contrib.layers.xavier_initializer())\n",
    "        if zero_pad:\n",
    "            lookup_table = tf.concat((tf.zeros(shape=[1, num_units]),\n",
    "                                      lookup_table[1:, :]), 0)\n",
    "        outputs = tf.nn.embedding_lookup(lookup_table, inputs)\n",
    "\n",
    "        if scale:\n",
    "            outputs = outputs * (num_units ** 0.5) \n",
    "\n",
    "    return outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### multihead层\n",
    "该层实现了下面功能：\n",
    "![image.png](attachment:image.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def multihead_attention(emb,\n",
    "                        queries, \n",
    "                        keys, \n",
    "                        num_units=None, \n",
    "                        num_heads=8, \n",
    "                        dropout_rate=0,\n",
    "                        is_training=True,\n",
    "                        causality=False,\n",
    "                        scope=\"multihead_attention\", \n",
    "                        reuse=None):\n",
    "    '''Applies multihead attention.\n",
    "    \n",
    "    Args:\n",
    "      queries: A 3d tensor with shape of [N, T_q, C_q].\n",
    "      keys: A 3d tensor with shape of [N, T_k, C_k].\n",
    "      num_units: A scalar. Attention size.\n",
    "      dropout_rate: A floating point number.\n",
    "      is_training: Boolean. Controller of mechanism for dropout.\n",
    "      causality: Boolean. If true, units that reference the future are masked. \n",
    "      num_heads: An int. Number of heads.\n",
    "      scope: Optional scope for `variable_scope`.\n",
    "      reuse: Boolean, whether to reuse the weights of a previous layer\n",
    "        by the same name.\n",
    "        \n",
    "    Returns\n",
    "      A 3d tensor with shape of (N, T_q, C)  \n",
    "    '''\n",
    "    with tf.variable_scope(scope, reuse=reuse):\n",
    "        # Set the fall back option for num_units\n",
    "        if num_units is None:\n",
    "            num_units = queries.get_shape().as_list[-1]\n",
    "        \n",
    "        # Linear projections\n",
    "        Q = tf.layers.dense(queries, num_units, activation=tf.nn.relu) # (N, T_q, C)\n",
    "        K = tf.layers.dense(keys, num_units, activation=tf.nn.relu) # (N, T_k, C)\n",
    "        V = tf.layers.dense(keys, num_units, activation=tf.nn.relu) # (N, T_k, C)\n",
    "        \n",
    "        # Split and concat\n",
    "        Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0) # (h*N, T_q, C/h) \n",
    "        K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0) # (h*N, T_k, C/h) \n",
    "        V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0) # (h*N, T_k, C/h) \n",
    "\n",
    "        # Multiplication\n",
    "        outputs = tf.matmul(Q_, tf.transpose(K_, [0, 2, 1])) # (h*N, T_q, T_k)\n",
    "        \n",
    "        # Scale\n",
    "        outputs = outputs / (K_.get_shape().as_list()[-1] ** 0.5)\n",
    "        \n",
    "        # Key Masking\n",
    "        key_masks = tf.sign(tf.abs(tf.reduce_sum(emb, axis=-1))) # (N, T_k)\n",
    "        key_masks = tf.tile(key_masks, [num_heads, 1]) # (h*N, T_k)\n",
    "        key_masks = tf.tile(tf.expand_dims(key_masks, 1), [1, tf.shape(queries)[1], 1]) # (h*N, T_q, T_k)\n",
    "        \n",
    "        paddings = tf.ones_like(outputs)*(-2**32+1)\n",
    "        outputs = tf.where(tf.equal(key_masks, 0), paddings, outputs) # (h*N, T_q, T_k)\n",
    "  \n",
    "        # Causality = Future blinding\n",
    "        if causality:\n",
    "            diag_vals = tf.ones_like(outputs[0, :, :]) # (T_q, T_k)\n",
    "            tril = tf.contrib.linalg.LinearOperatorTriL(diag_vals).to_dense() # (T_q, T_k)\n",
    "            masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(outputs)[0], 1, 1]) # (h*N, T_q, T_k)\n",
    "   \n",
    "            paddings = tf.ones_like(masks)*(-2**32+1)\n",
    "            outputs = tf.where(tf.equal(masks, 0), paddings, outputs) # (h*N, T_q, T_k)\n",
    "  \n",
    "        # Activation\n",
    "        outputs = tf.nn.softmax(outputs) # (h*N, T_q, T_k)\n",
    "         \n",
    "        # Query Masking\n",
    "        query_masks = tf.sign(tf.abs(tf.reduce_sum(emb, axis=-1))) # (N, T_q)\n",
    "        query_masks = tf.tile(query_masks, [num_heads, 1]) # (h*N, T_q)\n",
    "        query_masks = tf.tile(tf.expand_dims(query_masks, -1), [1, 1, tf.shape(keys)[1]]) # (h*N, T_q, T_k)\n",
    "        outputs *= query_masks # broadcasting. (N, T_q, C)\n",
    "          \n",
    "        # Dropouts\n",
    "        outputs = tf.layers.dropout(outputs, rate=dropout_rate, training=tf.convert_to_tensor(is_training))\n",
    "               \n",
    "        # Weighted sum\n",
    "        outputs = tf.matmul(outputs, V_) # ( h*N, T_q, C/h)\n",
    "        \n",
    "        # Restore shape\n",
    "        outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2 ) # (N, T_q, C)\n",
    "              \n",
    "        # Residual connection\n",
    "        outputs += queries\n",
    "              \n",
    "        # Normalize\n",
    "        outputs = normalize(outputs) # (N, T_q, C)\n",
    " \n",
    "    return outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### feedforward\n",
    "\n",
    "两层全连接，用卷积模拟加速运算，也可以使用dense层。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def feedforward(inputs, \n",
    "                num_units=[2048, 512],\n",
    "                scope=\"multihead_attention\", \n",
    "                reuse=None):\n",
    "    '''Point-wise feed forward net.\n",
    "    \n",
    "    Args:\n",
    "      inputs: A 3d tensor with shape of [N, T, C].\n",
    "      num_units: A list of two integers.\n",
    "      scope: Optional scope for `variable_scope`.\n",
    "      reuse: Boolean, whether to reuse the weights of a previous layer\n",
    "        by the same name.\n",
    "        \n",
    "    Returns:\n",
    "      A 3d tensor with the same shape and dtype as inputs\n",
    "    '''\n",
    "    with tf.variable_scope(scope, reuse=reuse):\n",
    "        # Inner layer\n",
    "        params = {\"inputs\": inputs, \"filters\": num_units[0], \"kernel_size\": 1,\n",
    "                  \"activation\": tf.nn.relu, \"use_bias\": True}\n",
    "        outputs = tf.layers.conv1d(**params)\n",
    "        \n",
    "        # Readout layer\n",
    "        params = {\"inputs\": outputs, \"filters\": num_units[1], \"kernel_size\": 1,\n",
    "                  \"activation\": None, \"use_bias\": True}\n",
    "        outputs = tf.layers.conv1d(**params)\n",
    "        \n",
    "        # Residual connection\n",
    "        outputs += inputs\n",
    "        \n",
    "        # Normalize\n",
    "        outputs = normalize(outputs)\n",
    "    \n",
    "    return outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### label_smoothing.\n",
    "对于训练有好处，将0变为接近零的小数，1变为接近1的数，原文：\n",
    "\n",
    "During training, we employed label smoothing of value \u000fls = 0.1 [36]. This hurts perplexity, as the model learns to be more unsure, but improves accuracy and BLEU score."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def label_smoothing(inputs, epsilon=0.1):\n",
    "    '''Applies label smoothing. See https://arxiv.org/abs/1512.00567.\n",
    "    \n",
    "    Args:\n",
    "      inputs: A 3d tensor with shape of [N, T, V], where V is the number of vocabulary.\n",
    "      epsilon: Smoothing rate.\n",
    "    \n",
    "    For example,\n",
    "    \n",
    "    ```\n",
    "    import tensorflow as tf\n",
    "    inputs = tf.convert_to_tensor([[[0, 0, 1], \n",
    "       [0, 1, 0],\n",
    "       [1, 0, 0]],\n",
    "      [[1, 0, 0],\n",
    "       [1, 0, 0],\n",
    "       [0, 1, 0]]], tf.float32)\n",
    "       \n",
    "    outputs = label_smoothing(inputs)\n",
    "    \n",
    "    with tf.Session() as sess:\n",
    "        print(sess.run([outputs]))\n",
    "    \n",
    "    >>\n",
    "    [array([[[ 0.03333334,  0.03333334,  0.93333334],\n",
    "        [ 0.03333334,  0.93333334,  0.03333334],\n",
    "        [ 0.93333334,  0.03333334,  0.03333334]],\n",
    "       [[ 0.93333334,  0.03333334,  0.03333334],\n",
    "        [ 0.93333334,  0.03333334,  0.03333334],\n",
    "        [ 0.03333334,  0.93333334,  0.03333334]]], dtype=float32)]   \n",
    "    ```    \n",
    "    '''\n",
    "    K = inputs.get_shape().as_list()[-1] # number of channels\n",
    "    return ((1-epsilon) * inputs) + (epsilon / K)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.2 搭建模型\n",
    "模型实现下图结构：\n",
    "![image.png](attachment:image.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Graph():\n",
    "    def __init__(self, is_training=True):\n",
    "        tf.reset_default_graph()\n",
    "        self.is_training = arg.is_training\n",
    "        self.hidden_units = arg.hidden_units\n",
    "        self.input_vocab_size = arg.input_vocab_size\n",
    "        self.label_vocab_size = arg.label_vocab_size\n",
    "        self.num_heads = arg.num_heads\n",
    "        self.num_blocks = arg.num_blocks\n",
    "        self.max_length = arg.max_length\n",
    "        self.lr = arg.lr\n",
    "        self.dropout_rate = arg.dropout_rate\n",
    "        \n",
    "        # input\n",
    "        self.x = tf.placeholder(tf.int32, shape=(None, None))\n",
    "        self.y = tf.placeholder(tf.int32, shape=(None, None))\n",
    "        # embedding\n",
    "        self.emb = embedding(self.x, vocab_size=self.input_vocab_size, num_units=self.hidden_units, scale=True, scope=\"enc_embed\")\n",
    "        self.enc = self.emb + embedding(tf.tile(tf.expand_dims(tf.range(tf.shape(self.x)[1]), 0), [tf.shape(self.x)[0], 1]),\n",
    "                                      vocab_size=self.max_length,num_units=self.hidden_units, zero_pad=False, scale=False,scope=\"enc_pe\")\n",
    "        ## Dropout\n",
    "        self.enc = tf.layers.dropout(self.enc, \n",
    "                                    rate=self.dropout_rate, \n",
    "                                    training=tf.convert_to_tensor(self.is_training))\n",
    "                \n",
    "        ## Blocks\n",
    "        for i in range(self.num_blocks):\n",
    "            with tf.variable_scope(\"num_blocks_{}\".format(i)):\n",
    "                ### Multihead Attention\n",
    "                self.enc = multihead_attention(emb = self.emb,\n",
    "                                               queries=self.enc, \n",
    "                                                keys=self.enc, \n",
    "                                                num_units=self.hidden_units, \n",
    "                                                num_heads=self.num_heads, \n",
    "                                                dropout_rate=self.dropout_rate,\n",
    "                                                is_training=self.is_training,\n",
    "                                                causality=False)\n",
    "                        \n",
    "        ### Feed Forward\n",
    "        self.outputs = feedforward(self.enc, num_units=[4*self.hidden_units, self.hidden_units])\n",
    "            \n",
    "                \n",
    "        # Final linear projection\n",
    "        self.logits = tf.layers.dense(self.outputs, self.label_vocab_size)\n",
    "        self.preds = tf.to_int32(tf.argmax(self.logits, axis=-1))\n",
    "        self.istarget = tf.to_float(tf.not_equal(self.y, 0))\n",
    "        self.acc = tf.reduce_sum(tf.to_float(tf.equal(self.preds, self.y))*self.istarget)/ (tf.reduce_sum(self.istarget))\n",
    "        tf.summary.scalar('acc', self.acc)\n",
    "                \n",
    "        if is_training:  \n",
    "            # Loss\n",
    "            self.y_smoothed = label_smoothing(tf.one_hot(self.y, depth=self.label_vocab_size))\n",
    "            self.loss = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.y_smoothed)\n",
    "            self.mean_loss = tf.reduce_sum(self.loss*self.istarget) / (tf.reduce_sum(self.istarget))\n",
    "               \n",
    "            # Training Scheme\n",
    "            self.global_step = tf.Variable(0, name='global_step', trainable=False)\n",
    "            self.optimizer = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.9, beta2=0.98, epsilon=1e-8)\n",
    "            self.train_op = self.optimizer.minimize(self.mean_loss, global_step=self.global_step)\n",
    "                   \n",
    "            # Summary \n",
    "            tf.summary.scalar('mean_loss', self.mean_loss)\n",
    "            self.merged = tf.summary.merge_all()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. 训练模型\n",
    "### 3.1 参数设定"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_hparams():\n",
    "    params = tf.contrib.training.HParams(\n",
    "        num_heads = 8,\n",
    "        num_blocks = 6,\n",
    "        # vocab\n",
    "        input_vocab_size = 50,\n",
    "        label_vocab_size = 50,\n",
    "        # embedding size\n",
    "        max_length = 100,\n",
    "        hidden_units = 512,\n",
    "        dropout_rate = 0.2,\n",
    "        lr = 0.0003,\n",
    "        is_training = True)\n",
    "    return params\n",
    "\n",
    "        \n",
    "arg = create_hparams()\n",
    "arg.input_vocab_size = len(pny2id)\n",
    "arg.label_vocab_size = len(han2id)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.2 模型训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From d:\\ProgramData\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\util\\deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "`NHWC` for data_format is deprecated, use `NWC` instead\n",
      "WARNING:tensorflow:From <ipython-input-14-862b249aa009>:53: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "\n",
      "Future major versions of TensorFlow will allow gradients to flow\n",
      "into the labels input on backprop by default.\n",
      "\n",
      "See tf.nn.softmax_cross_entropy_with_logits_v2.\n",
      "\n",
      "epochs 5 : average loss =  1.6399681091308593\n",
      "epochs 10 : average loss =  1.1646613264083863\n",
      "epochs 15 : average loss =  1.157580156326294\n",
      "epochs 20 : average loss =  1.1407248640060426\n",
      "epochs 25 : average loss =  1.1298853492736816\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "epochs = 25\n",
    "batch_size = 4\n",
    "\n",
    "g = Graph(arg)\n",
    "\n",
    "saver =tf.train.Saver()\n",
    "with tf.Session() as sess:\n",
    "    merged = tf.summary.merge_all()\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    if os.path.exists('logs/model.meta'):\n",
    "        saver.restore(sess, 'logs/model')\n",
    "    writer = tf.summary.FileWriter('tensorboard/lm', tf.get_default_graph())\n",
    "    for k in range(epochs):\n",
    "        total_loss = 0\n",
    "        batch_num = len(input_num) // batch_size\n",
    "        batch = get_batch(input_num, label_num, batch_size)\n",
    "        for i in range(batch_num):\n",
    "            input_batch, label_batch = next(batch)\n",
    "            feed = {g.x: input_batch, g.y: label_batch}\n",
    "            cost,_ = sess.run([g.mean_loss,g.train_op], feed_dict=feed)\n",
    "            total_loss += cost\n",
    "            if (k * batch_num + i) % 10 == 0:\n",
    "                rs=sess.run(merged, feed_dict=feed)\n",
    "                writer.add_summary(rs, k * batch_num + i)\n",
    "        if (k+1) % 5 == 0:\n",
    "            print('epochs', k+1, ': average loss = ', total_loss/batch_num)\n",
    "    saver.save(sess, 'logs/model')\n",
    "    writer.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.3 模型推断"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from logs/model\n",
      "输入测试拼音: shen1 ye4 shi2 er4 dian3 zhong1 ta1 zhan4 zai4 shao4 wei4 shang4 huan2 shi4 zhou1 wei2 yin1 sen1 ke3 bu4 yue4 xiang3 yue4 hai4 pa4 bao4 qiang1 ku1 zhe pao3 hui2 ying2 fang2\n",
      "深夜十二点钟他站在哨位上环视周围阴森可怖越想越害怕抱枪哭着跑回营房\n",
      "输入测试拼音: wu3 yue4 er4 shi2 jiu3 ri4 ye4 wan3 ao4 da4 li4 ya4 shou3 dou1 kan1 pei2 la1 de huang2 jia1 ju4 yuan4 re4 lie4 er2 chong1 man3 zhe zhen1 qing2\n",
      "五月二十九日夜晚澳大利亚首都堪培拉的皇家剧院热烈而充满着真情\n",
      "输入测试拼音: tai4 hu2 dong1 an4 yi1 dai4 de yun2 tuan2 you2 dan4 dao4 nong2 zai4 you2 nong2 dao4 dan4 er2 shang4 hai3 shi4 qu1 shang4 kong1 de yun2 yue4 lai2 yue4 shao3 yue4 lai2 yue4 xi1\n",
      "太湖东岸一带的云团由淡到浓再由浓到淡而上海市区上空的云越来越少越来越稀\n",
      "输入测试拼音: exit\n"
     ]
    }
   ],
   "source": [
    "arg.is_training = False\n",
    "\n",
    "g = Graph(arg)\n",
    "\n",
    "saver =tf.train.Saver()\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    saver.restore(sess, 'logs/model')\n",
    "    while True:\n",
    "        line = input('输入测试拼音: ')\n",
    "        if line == 'exit': break\n",
    "        line = line.strip('\\n').split(' ')\n",
    "        x = np.array([pny2id.index(pny) for pny in line])\n",
    "        x = x.reshape(1, -1)\n",
    "        preds = sess.run(g.preds, {g.x: x})\n",
    "        got = ''.join(han2id[idx] for idx in preds[0])\n",
    "        print(got)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
